package top.cyuw.simplerpc.remoting.codec;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import lombok.extern.slf4j.Slf4j;
import top.cyuw.simplerpc.compress.Compress;
import top.cyuw.simplerpc.compress.CompressSelector;
import top.cyuw.simplerpc.constant.RpcConstants;
import top.cyuw.simplerpc.exception.RpcException;
import top.cyuw.simplerpc.extension.ExtensionLoader;
import top.cyuw.simplerpc.dto.RpcMessage;
import top.cyuw.simplerpc.dto.RpcRequest;
import top.cyuw.simplerpc.dto.RpcResponse;
import top.cyuw.simplerpc.serialize.Serializer;
import top.cyuw.simplerpc.serialize.SerializerSelector;

import java.util.Arrays;

/**
 * @author chen
 * @date 2023/3/13 10:10
 */
@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {

    public RpcMessageDecoder() {
        super(RpcConstants.MAX_FRAME_LENGTH,
                RpcConstants.LENGTH_MAGIC_NUMBER + RpcConstants.LENGTH_VERSION, RpcConstants.LENGTH_FULL_LENGTH,
                -(RpcConstants.LENGTH_MAGIC_NUMBER + RpcConstants.LENGTH_VERSION + RpcConstants.LENGTH_FULL_LENGTH), 0);
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        Object decoded = super.decode(ctx, in);
        if (decoded instanceof ByteBuf) {
            ByteBuf buf = (ByteBuf) decoded;
            if (buf.readableBytes() >= RpcConstants.LENGTH_HEAD) {
                try {
                    return doDecode(buf);
                } catch (Exception e) {
                    log.error("decode message failed: " + e.getMessage(), e);
                    throw e;
                } finally {
                    buf.release();
                }
            }
        }
        return decoded;
    }

    private Object doDecode(ByteBuf buf) {
        checkMagicNumber(buf);
        checkVersion(buf);
        int fullLength = buf.readInt();
        byte messageType = buf.readByte();
        int requestId = buf.readInt();
        RpcMessage rpcMessage = RpcMessage.builder()
                .messageType(messageType)
                .requestId(requestId).build();

        if (messageType == RpcConstants.MESSAGE_TYPE_RPC_REQUEST
                || messageType == RpcConstants.MESSAGE_TYPE_RPC_RESPONSE) {
            int bodyLength = fullLength - RpcConstants.LENGTH_HEAD;
            byte[] body = new byte[bodyLength];
            buf.readBytes(body);

            String compressName = CompressSelector.getCompress();
            Compress compress = ExtensionLoader.of(Compress.class).getExtension(compressName);
            body = compress.decompress(body);

            String serializerName = SerializerSelector.getSerializer();
            Serializer serializer = ExtensionLoader.of(Serializer.class).getExtension(serializerName);
            if (messageType == RpcConstants.MESSAGE_TYPE_RPC_REQUEST) {
                rpcMessage.setBody(serializer.deserialize(body, RpcRequest.class));
            } else {
                rpcMessage.setBody(serializer.deserialize(body, RpcResponse.class));
            }
        }

        return rpcMessage;
    }

    private void checkVersion(ByteBuf buf) {
        byte version = buf.readByte();
        if (version != RpcConstants.VERSION) {
            throw new RpcException("It isn't compatible version " + version);
        }
    }

    private void checkMagicNumber(ByteBuf buf) {
        int len = RpcConstants.MAGIC_NUMBER.length;
        byte[] magicNumber = new byte[len];
        buf.readBytes(magicNumber);
        for (int i = 0; i < len; i++) {
            if (magicNumber[i] != RpcConstants.MAGIC_NUMBER[i]) {
                throw new RpcException("unknown magic number: " + Arrays.toString(magicNumber));
            }
        }
    }
}
